{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 子词嵌入\n",
    ":label:`sec_fasttext`\n",
    "\n",
    "在英语中，“helps”、“helped”和“helping”等单词都是同一个词“help”的变形形式。“dog”和“dogs”之间的关系与“cat”和“cats”之间的关系相同，“boy”和“boyfriend”之间的关系与“girl”和“girlfriend”之间的关系相同。在法语和西班牙语等其他语言中，许多动词有40多种变形形式，而在芬兰语中，名词最多可能有15种变形。在语言学中，形态学研究单词形成和词汇关系。但是，word2vec和GloVe都没有对词的内部结构进行探讨。\n",
    "\n",
    "## fastText模型\n",
    "\n",
    "回想一下词在word2vec中是如何表示的。在跳元模型和连续词袋模型中，同一词的不同变形形式直接由不同的向量表示，不需要共享参数。为了使用形态信息，*fastText模型*提出了一种*子词嵌入*方法，其中子词是一个字符$n$-gram :cite:`Bojanowski.Grave.Joulin.ea.2017`。fastText可以被认为是子词级跳元模型，而非学习词级向量表示，其中每个*中心词*由其子词级向量之和表示。\n",
    "\n",
    "让我们来说明如何以单词“where”为例获得fastText中每个中心词的子词。首先，在词的开头和末尾添加特殊字符“&lt;”和“&gt;”，以将前缀和后缀与其他子词区分开来。\n",
    "然后，从词中提取字符$n$-gram。\n",
    "例如，值$n=3$时，我们将获得长度为3的所有子词：\n",
    "“&lt;wh”、“whe”、“her”、“ere”、“re&gt;”和特殊子词“&lt;where&gt;”。\n",
    "\n",
    "在fastText中，对于任意词$w$，用$\\mathcal{G}_w$表示其长度在3和6之间的所有子词与其特殊子词的并集。词表是所有词的子词的集合。假设$\\mathbf{z}_g$是词典中的子词$g$的向量，则跳元模型中作为中心词的词$w$的向量$\\mathbf{v}_w$是其子词向量的和：\n",
    "\n",
    "$$\\mathbf{v}_w = \\sum_{g\\in\\mathcal{G}_w} \\mathbf{z}_g.$$\n",
    "\n",
    "fastText的其余部分与跳元模型相同。与跳元模型相比，fastText的词量更大，模型参数也更多。此外，为了计算一个词的表示，它的所有子词向量都必须求和，这导致了更高的计算复杂度。然而，由于具有相似结构的词之间共享来自子词的参数，罕见词甚至词表外的词在fastText中可能获得更好的向量表示。\n",
    "\n",
    "## 字节对编码（Byte Pair Encoding）\n",
    ":label:`subsec_Byte_Pair_Encoding`\n",
    "\n",
    "在fastText中，所有提取的子词都必须是指定的长度，例如$3$到$6$，因此词表大小不能预定义。为了在固定大小的词表中允许可变长度的子词，我们可以应用一种称为*字节对编码*（Byte Pair Encoding，BPE）的压缩算法来提取子词 :cite:`Sennrich.Haddow.Birch.2015`。\n",
    "\n",
    "字节对编码执行训练数据集的统计分析，以发现单词内的公共符号，诸如任意长度的连续字符。从长度为1的符号开始，字节对编码迭代地合并最频繁的连续符号对以产生新的更长的符号。请注意，为提高效率，不考虑跨越单词边界的对。最后，我们可以使用像子词这样的符号来切分单词。字节对编码及其变体已经用于诸如GPT-2 :cite:`Radford.Wu.Child.ea.2019`和RoBERTa :cite:`Liu.Ott.Goyal.ea.2019`等自然语言处理预训练模型中的输入表示。在下面，我们将说明字节对编码是如何工作的。\n",
    "\n",
    "首先，我们将符号词表初始化为所有英文小写字符、特殊的词尾符号`'_'`和特殊的未知符号`'[UNK]'`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/Functions.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.stream.*;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] symbols = {\n",
    "            \"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\", \"j\", \"k\", \"l\", \"m\", \"n\", \"o\", \"p\",\n",
    "            \"q\", \"r\", \"s\", \"t\", \"u\", \"v\", \"w\", \"x\", \"y\", \"z\", \"_\", \"[UNK]\"\n",
    "        };"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为我们不考虑跨越词边界的符号对，所以我们只需要一个字典`rawTokenFreqs`将词映射到数据集中的频率（出现次数）。注意，特殊符号`'_'`被附加到每个词的尾部，以便我们可以容易地从输出符号序列（例如，“a_all er_man”）恢复单词序列（例如，“a_all er_man”）。由于我们仅从单个字符和特殊符号的词开始合并处理，所以在每个词（词典`tokenFreqs`的键）内的每对连续字符之间插入空格。换句话说，空格是词中符号之间的分隔符。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HashMap<String, Integer> rawTokenFreqs = new HashMap<>();\n",
    "rawTokenFreqs.put(\"fast_\", 4);\n",
    "rawTokenFreqs.put(\"faster_\", 3);\n",
    "rawTokenFreqs.put(\"tall_\", 5);\n",
    "rawTokenFreqs.put(\"taller_\", 4);\n",
    "\n",
    "HashMap<String, Integer> tokenFreqs = new HashMap<>();\n",
    "for (Map.Entry<String, Integer> e : rawTokenFreqs.entrySet()) {\n",
    "    String token = e.getKey();\n",
    "    tokenFreqs.put(String.join(\" \", token.split(\"\")), rawTokenFreqs.get(token));\n",
    "}\n",
    "\n",
    "tokenFreqs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们定义以下`getMaxFreqPair`函数，其返回词内最频繁的连续符号对，其中词来自输入词典`tokenFreqs`的键。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<String, String> getMaxFreqPair(HashMap<String, Integer> tokenFreqs) {\n",
    "    HashMap<Pair<String, String>, Integer> pairs = new HashMap<>();\n",
    "    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {\n",
    "        // Key of 'pairs' is a tuple of two consecutive symbols\n",
    "        String token = e.getKey();\n",
    "        Integer freq = e.getValue();\n",
    "        String[] symbols = token.split(\" \");\n",
    "        for (int i = 0; i < symbols.length - 1; i++) {\n",
    "            pairs.put(\n",
    "                    new Pair<>(symbols[i], symbols[i + 1]),\n",
    "                    pairs.getOrDefault(new Pair<>(symbols[i], symbols[i + 1]), 0) + freq);\n",
    "        }\n",
    "    }\n",
    "    int max = 0; // Key of `pairs` with the max value\n",
    "    Pair<String, String> maxFreqPair = null;\n",
    "    for (Map.Entry<Pair<String, String>, Integer> pair : pairs.entrySet()) {\n",
    "        if (max < pair.getValue()) {\n",
    "            max = pair.getValue();\n",
    "            maxFreqPair = pair.getKey();\n",
    "        }\n",
    "    }\n",
    "    return maxFreqPair;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作为基于连续符号频率的贪心方法，字节对编码将使用以下`mergeSymbols`函数来合并最频繁的连续符号对以产生新符号。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<HashMap<String, Integer>, String[]> mergeSymbols(\n",
    "        Pair<String, String> maxFreqPair, HashMap<String, Integer> tokenFreqs) {\n",
    "    ArrayList<String> symbols = new ArrayList<>();\n",
    "    symbols.add(maxFreqPair.getKey() + maxFreqPair.getValue());\n",
    "\n",
    "    HashMap<String, Integer> newTokenFreqs = new HashMap<>();\n",
    "    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {\n",
    "        String token = e.getKey();\n",
    "        String newToken =\n",
    "                token.replace(\n",
    "                        maxFreqPair.getKey() + \" \" + maxFreqPair.getValue(),\n",
    "                        maxFreqPair.getKey() + \"\" + maxFreqPair.getValue());\n",
    "        newTokenFreqs.put(newToken, tokenFreqs.get(token));\n",
    "    }\n",
    "    return new Pair(newTokenFreqs, symbols.toArray(new String[symbols.size()]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们对词典`tokenFreqs`的键迭代地执行字节对编码算法。在第一次迭代中，最频繁的连续符号对是`'t'`和`'a'`，因此字节对编码将它们合并以产生新符号`'ta'`。在第二次迭代中，字节对编码继续合并`'ta'`和`'l'`以产生另一个新符号`'tal'`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numMerges = 10;\n",
    "for (int i = 0; i < numMerges; i++) {\n",
    "    Pair<String, String> maxFreqPair = getMaxFreqPair(tokenFreqs);\n",
    "    Pair<HashMap<String, Integer>, String[]> pair =\n",
    "            mergeSymbols(maxFreqPair, tokenFreqs);\n",
    "    tokenFreqs = pair.getKey();\n",
    "    symbols =\n",
    "            Stream.concat(Arrays.stream(symbols), Arrays.stream(pair.getValue()))\n",
    "                    .toArray(String[]::new);\n",
    "    System.out.println(\n",
    "            \"合并 #\"\n",
    "                    + (i + 1)\n",
    "                    + \": (\"\n",
    "                    + maxFreqPair.getKey()\n",
    "                    + \", \"\n",
    "                    + maxFreqPair.getValue()\n",
    "                    + \")\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在字节对编码的10次迭代之后，我们可以看到列表`symbols`现在又包含10个从其他符号迭代合并而来的符号。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Arrays.toString(symbols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于在词典`rawTokenFreqs`的键中指定的同一数据集，作为字节对编码算法的结果，数据集中的每个词现在被子词“fast_”、“fast”、“er_”、“tall_”和“tall”分割。例如，单词“fast er_”和“tall er_”分别被分割为“fast er_”和“tall er_”。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenFreqs.keySet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "请注意，字节对编码的结果取决于正在使用的数据集。我们还可以使用从一个数据集学习的子词来切分另一个数据集的单词。作为一种贪心方法，下面的`segmentBPE`函数尝试将单词从输入参数`symbols`分成可能最长的子词。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static List<String> segmentBPE(String[] tokens, String[] symbols) {\n",
    "    List<String> outputs = new ArrayList<>();\n",
    "    for (String token : tokens) {\n",
    "        int start = 0;\n",
    "        int end = token.length();\n",
    "        ArrayList<String> curOutput = new ArrayList<>();\n",
    "        // Segment token with the longest possible subwords from symbols\n",
    "        while (start < token.length() && start < end) {\n",
    "            if (Arrays.asList(symbols).contains(token.substring(start, end))) {\n",
    "                curOutput.add(token.substring(start, end));\n",
    "                start = end;\n",
    "                end = token.length();\n",
    "            } else {\n",
    "                end -= 1;\n",
    "            }\n",
    "        }\n",
    "        if (start < tokens.length) {\n",
    "            curOutput.add(\"[UNK]\");\n",
    "        }\n",
    "        String temp = \"\";\n",
    "        for (String s : curOutput) {\n",
    "            temp += s + \" \";\n",
    "        }\n",
    "        outputs.add(temp.trim());\n",
    "    }\n",
    "    return outputs;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用列表`symbols`中的子词（从前面提到的数据集学习）来表示另一个数据集的`tokens`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] tokens = new String[] {\"tallest_\", \"fatter_\"};\n",
    "System.out.println(segmentBPE(tokens, symbols));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* fastText模型提出了一种子词嵌入方法：基于word2vec中的跳元模型，它将中心词表示为其子词向量之和。\n",
    "* 字节对编码执行训练数据集的统计分析，以发现词内的公共符号。作为一种贪心方法，字节对编码迭代地合并最频繁的连续符号对。\n",
    "* 子词嵌入可以提高稀有词和词典外词的表示质量。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 例如，英语中大约有$3\\times 10^8$种可能的$6$-元组。子词太多会有什么问题呢？如何解决这个问题？提示:请参阅fastText论文第3.2节末尾 :cite:`Bojanowski.Grave.Joulin.ea.2017`。\n",
    "1. 如何在连续词袋模型的基础上设计一个子词嵌入模型？\n",
    "1. 要获得大小为$m$的词表，当初始符号词表大小为$n$时，需要多少合并操作？\n",
    "1. 如何扩展字节对编码的思想来提取短语？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
